import java.util.Scanner;

// Note: the submitted class must be named Main (rename Main1 before submitting) and the file must not contain any package declaration.
public class Main1 {
    /**
     * Reads {@code n m a b} from standard input and prints the best achievable total score.
     *
     * <p>An item of type A consumes 2 units of {@code n} and 1 unit of {@code m} and is worth
     * {@code a}; an item of type B consumes 1 unit of {@code n} and 2 units of {@code m} and is
     * worth {@code b}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        try (Scanner in = new Scanner(System.in)) {
            // Read as long: 2*n and a*x below must not overflow int arithmetic.
            long n = in.nextLong();
            long m = in.nextLong();
            long a = in.nextLong();
            long b = in.nextLong();
            System.out.println(maxScore(n, m, a, b));
        }
    }

    /**
     * Returns the maximum of {@code a*x + b*y} over non-negative integers {@code x, y} with
     * {@code 2x + y <= n} and {@code x + 2y <= m}.
     *
     * <p>For a fixed {@code x} the best {@code y} is {@code min(n - 2x, (m - x) / 2)} (assumes
     * {@code b >= 0}; NOTE(review): confirm scores are non-negative). The resulting score f(x):
     * <ul>
     *   <li>while {@code (m - x) / 2} is the binding term, moves by the constant {@code 2a - b}
     *       every two steps, so it is linear along each parity class;</li>
     *   <li>once {@code n - 2x} is binding (the switch happens once, near
     *       {@code x = (2n - m) / 3}), it is linear in {@code x}.</li>
     * </ul>
     * A linear function attains its maximum at an endpoint, so it suffices to evaluate
     * {@code x = 0, 1}, the few values around the switch point, and {@code xMax - 1, xMax}.
     * Runs in O(1) regardless of input size.
     *
     * @param n available units of the first resource
     * @param m available units of the second resource
     * @param a score for one item of type A (uses 2 of n, 1 of m)
     * @param b score for one item of type B (uses 1 of n, 2 of m)
     * @return the maximum total score; 0 when no item can be built
     */
    static long maxScore(long n, long m, long a, long b) {
        // With n <= 0 or m <= 0 neither item type fits.
        if (n <= 0 || m <= 0) {
            return 0;
        }
        long xMax = Math.min(n / 2, m);
        // Approximate switch point between the two binding constraints.
        long t = Math.floorDiv(2 * n - m, 3);
        long[] candidates = {0, 1, t - 2, t - 1, t, t + 1, t + 2, xMax - 1, xMax};

        long best = Long.MIN_VALUE;
        for (long x : candidates) {
            if (x < 0 || x > xMax) {
                continue;
            }
            long y = Math.min(n - 2 * x, (m - x) / 2);
            best = Math.max(best, a * x + b * y);
        }
        return best;
    }
}